#pragma once

#include "Common.hpp"
#include "Log.hpp"
#include "InetAddr.hpp"



namespace SocketModule
{
    using namespace LogModule;
    // Default listen() backlog used when BuildTcpSocketMethod() is called
    // without an explicit backlog argument.
    static constexpr int gbacklog = 16;
    // Forward declaration: Socket::Accept() names TcpSocket in its return type
    // before the class itself is defined further down in this namespace.
    class TcpSocket;

    // Template Method pattern.
    //
    // Socket declares the primitive steps (create / bind / listen / accept /
    // recv / send / connect / close) as pure virtual hooks and offers two
    // non-virtual "build" methods that run those steps in a fixed order.
    // Concrete socket types only implement the individual steps.
    //
    // NOTE(review): the hooks and the destructor are protected here, so they
    // cannot be invoked (or the object deleted) through a Socket pointer or
    // reference; TcpSocket re-declares its overrides as public. Confirm this
    // is the intended access design.
    class Socket
    {
    protected:
        virtual ~Socket() = default;

        // Create the underlying socket; implementations do not return on failure.
        virtual void SocketOrDie() = 0;
        // Bind the socket to `port`; implementations do not return on failure.
        virtual void BindOrDie(uint16_t port) = 0;
        // Put the socket into listening state with the given backlog.
        virtual void ListenOrDie(int backlog) = 0;
        // Accept one connection and report the peer's address through `client`.
        virtual std::shared_ptr<TcpSocket> Accept(InetAddr *client) = 0;
        // Release the underlying socket.
        virtual void Close() = 0;
        // Receive data into `*out`; returns the implementation's status/byte count.
        virtual int Recv(std::string *out) = 0;
        // Send `message`; returns the implementation's status/byte count.
        virtual int Send(const std::string &message) = 0;
        // Connect to server_ip:port; returns the implementation's status.
        virtual int Connect(const std::string &server_ip, uint16_t port) = 0;

    public:
        // Server-side recipe: create, bind to `port`, then listen.
        // `backlog` defaults to gbacklog.
        void BuildTcpSocketMethod(uint16_t port, int backlog = gbacklog)
        {
            SocketOrDie();
            BindOrDie(port);
            ListenOrDie(backlog);
        }

        // Client-side recipe: only create the socket; the caller is expected
        // to invoke Connect() on the concrete object afterwards.
        void BuildTcpClientSocketMethod()
        {
            SocketOrDie();
        }
    };

    // Sentinel stored in a socket object that does not hold an open descriptor.
    static constexpr int defaultfd = -1;

    // Socket implementation on top of an AF_INET / SOCK_STREAM descriptor.
    //
    // The same class is used for listening sockets and for connected sockets
    // (the descriptor returned by accept(), or one that Connect() was called on).
    // The destructor does NOT close the descriptor; call Close() explicitly.
    class TcpSocket : public Socket
    {
    public:
        // Start without a descriptor; SocketOrDie() creates one later.
        TcpSocket()
            : _sockfd(defaultfd)
        {
        }

        // Wrap an already-open descriptor (Accept() uses this for the fd
        // returned by accept()).
        TcpSocket(int fd)
            : _sockfd(fd)
        {
        }

        ~TcpSocket() {}

        // Create the TCP socket. Logs and terminates the process with
        // SOCKET_ERR if socket() fails.
        void SocketOrDie() override
        {
            _sockfd = ::socket(AF_INET, SOCK_STREAM, 0);
            if (_sockfd < 0)
            {
                LOG(LogLevel::FATAL) << "socket error";
                exit(SOCKET_ERR);
            }
            LOG(LogLevel::INFO) << "socket success";
        }

        // Bind the socket to the address InetAddr builds from `port` alone
        // (presumably the wildcard address -- confirm in InetAddr.hpp).
        // Logs and terminates the process with BIND_ERR on failure.
        void BindOrDie(uint16_t port) override
        {
            InetAddr local(port);
            const int rc = ::bind(_sockfd, local.NetAddrPtr(), local.NetAddrLen());
            if (rc < 0)
            {
                LOG(LogLevel::FATAL) << "bind error";
                exit(BIND_ERR);
            }
            LOG(LogLevel::INFO) << "bind success";
        }

        // Start listening. Logs and terminates the process with LISTEN_ERR
        // on failure.
        void ListenOrDie(int backlog) override
        {
            const int rc = ::listen(_sockfd, backlog);
            if (rc < 0)
            {
                LOG(LogLevel::FATAL) << "listen error";
                exit(LISTEN_ERR);
            }
            LOG(LogLevel::INFO) << "listen success";
        }

        // Accept one pending connection.
        // On success returns a TcpSocket wrapping the new descriptor and, when
        // `client` is non-null, stores the peer address in it.
        // On failure logs a warning and returns nullptr (the process keeps
        // running, unlike the *OrDie steps).
        std::shared_ptr<TcpSocket> Accept(InetAddr *client) override
        {
            struct sockaddr_in peer{};
            socklen_t len = sizeof(peer);
            const int fd = ::accept(_sockfd, reinterpret_cast<struct sockaddr *>(&peer), &len);
            if (fd < 0)
            {
                LOG(LogLevel::WARNING) << "accept warning ...";
                return nullptr;
            }
            // A caller that does not care about the peer address may pass nullptr.
            if (client != nullptr)
            {
                client->SetAddr(peer);
            }
            return std::make_shared<TcpSocket>(fd);
        }

        // Close the descriptor if one is held. The member is reset afterwards
        // so that a repeated Close() is a no-op instead of closing whatever
        // unrelated descriptor the kernel may have reused the number for.
        void Close() override
        {
            if (_sockfd >= 0)
            {
                ::close(_sockfd);
                _sockfd = defaultfd;
            }
        }

        // Receive up to one buffer's worth of bytes and APPEND them to `*out`
        // (appending is deliberate: callers accumulate a stream across calls).
        // Returns recv()'s result: >0 bytes appended, 0 peer closed, <0 error.
        int Recv(std::string *out) override
        {
            char buffer[1024];
            const ssize_t n = ::recv(_sockfd, buffer, sizeof(buffer), 0);
            if (n > 0)
            {
                // Append by explicit length: the payload may contain '\0'
                // bytes, which a C-string append would silently truncate at.
                out->append(buffer, static_cast<std::string::size_type>(n));
            }
            return static_cast<int>(n);
        }

        // Send `message` with a single send() call and return its result
        // (bytes actually sent, which may be fewer than message.size(), or <0
        // on error).
        int Send(const std::string &message) override
        {
            return static_cast<int>(::send(_sockfd, message.c_str(), message.size(), 0));
        }

        // Connect this socket to server_ip:port; returns connect()'s result
        // (0 on success, <0 on failure).
        int Connect(const std::string &server_ip, uint16_t port) override
        {
            InetAddr server(server_ip, port);
            return ::connect(_sockfd, server.NetAddrPtr(), server.NetAddrLen());
        }

    private:
        int _sockfd; // may be either a listening descriptor or a connected one
    };
}
